# ======================================================================
# Copyright CERFACS (May 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import numpy

from qaths.utils.constants import ABSOLUTE_TOLERANCE


def _real2str(num: float, decimals: int, atol: float, force_ones: bool) -> str:
    """Format a real coefficient, hiding it when it is (approximately) 1.

    :param num: value to format.
    :param decimals: number of digits printed after the decimal point.
    :param atol: absolute tolerance; ``num`` counts as 1 when it is within
        ``max(10 ** -decimals / 2, atol)`` of 1.
    :param force_ones: if True, a value that counts as 1 is printed anyway.
    :return: ``num`` formatted with ``decimals`` decimals, or "" when ``num``
        counts as 1 and ``force_ones`` is False (this lets callers leave a
        unit coefficient implicit).
    """
    if not force_ones:
        tolerance = max(10 ** (-decimals) / 2, atol)
        if not (abs(num - 1) > tolerance):
            # Unit coefficient: nothing to print.
            return ""
    return "{:.{}f}".format(num, decimals)


def _complex2str(num: complex, decimals: int, atol: float = ABSOLUTE_TOLERANCE) -> str:
    """Format a complex coefficient for display in front of a ket.

    Parts whose magnitude is <= ``atol`` are omitted. When both parts are
    shown the result is parenthesised, e.g. ``0.5+0.5j`` with ``decimals=2``
    gives ``"(0.50+i0.50)"``. A part that counts as 1 (see ``_real2str``) is
    left implicit when it stands alone, so ``1+0j`` gives ``""`` and ``-1j``
    gives ``"-i"``.

    :param num: the complex value to format.
    :param decimals: number of decimals printed for each part.
    :param atol: absolute tolerance below which a part is treated as zero.
    :return: the formatted coefficient, possibly empty.
    """
    has_real = abs(num.real) > atol
    has_imag = abs(num.imag) > atol
    ret = ""
    if has_real:
        # Force the real part to be printed even if it is 1 when an imaginary
        # part follows; otherwise "1+i0.50" would become "+i0.50".
        ret += _real2str(num.real, decimals, atol, force_ones=has_imag)
    if has_imag:
        if has_real:
            ret += "+" if num.imag > 0 else "-"
        elif num.imag < 0:
            # Purely imaginary negative value: the sign must still appear,
            # since only the magnitude of the imaginary part is printed below.
            ret += "-"
        ret += "i" + _real2str(abs(num.imag), decimals, atol, force_ones=False)
    if has_real and has_imag:
        ret = "(" + ret + ")"
    return ret


def _num2str(num, decimals: int, atol: float = 1e-7) -> str:
    """Format any scalar accepted by ``complex()`` as a ket coefficient.

    Thin wrapper around ``_complex2str``. Note that the default ``atol`` here
    is the literal 1e-7 and is not tied to ``ABSOLUTE_TOLERANCE``.
    """
    value = complex(num)
    return _complex2str(value, decimals, atol=atol)


def msb2str(num: int, length: int) -> str:
    """Return the ket of ``num`` written in binary, most-significant bit first.

    The bit string is left-padded with zeros to at least ``length`` digits
    (it is never truncated), e.g. ``msb2str(5, 4) == "|0101〉"``.
    ``num`` is expected to be non-negative: ``bin()`` of a negative integer
    has a ``-0b`` prefix that is not handled here.
    """
    bits = bin(num)[2:].zfill(length)
    return "|" + bits + "〉"


def msbqstate2str(
    statevector: numpy.ndarray, decimals: int = 2, atol: float = ABSOLUTE_TOLERANCE
) -> str:
    """Format ``statevector`` as a sum of kets, most-significant bit first.

    Basis state ``i`` is written as ``msb2str(i, n)`` with
    ``n = ceil(log2(statevector.size))``, preceded by its coefficient as
    formatted by ``_num2str``. Only amplitudes whose modulus exceeds
    ``10 ** -decimals / 2`` are kept; ``atol`` is only forwarded to the
    coefficient formatter and is not used for filtering. Terms are joined
    with ``" + "``.

    :param statevector: amplitudes indexed by basis state.
    :param decimals: number of decimals used for the coefficients.
    :param atol: tolerance forwarded to the coefficient formatter.
    :return: the formatted expansion, or "" when no amplitude is kept.
    """
    qubit_count = numpy.ceil(numpy.log2(statevector.size)).astype(int)
    display_threshold = 10 ** (-decimals) / 2
    terms = [
        _num2str(amplitude, decimals, atol) + msb2str(index, qubit_count)
        for index, amplitude in enumerate(statevector)
        if abs(amplitude) > display_threshold
    ]
    return " + ".join(terms)


def lsbqstate2str(
    statevector: numpy.ndarray, decimals: int = 2, atol: float = ABSOLUTE_TOLERANCE
) -> str:
    """Format ``statevector`` as a sum of kets, least-significant bit first.

    Basis state ``i`` is written with the bits of ``i`` (zero-padded to
    ``ceil(log2(statevector.size))`` digits) reversed, so with two qubits the
    index 1 is printed as ``|10〉``. Each ket is preceded by its coefficient as
    formatted by ``_num2str`` and terms are joined with ``" + "``.

    :param statevector: amplitudes indexed by basis state.
    :param decimals: number of decimals used for the coefficients.
    :param atol: amplitudes whose modulus is <= ``atol`` are omitted; also
        forwarded to the coefficient formatter.
    :return: the formatted expansion, or "" when every amplitude is omitted.
    """
    n = numpy.ceil(numpy.log2(statevector.size)).astype(int)
    terms = []
    for i, amplitude in enumerate(statevector):
        if abs(amplitude) > atol:
            # Reverse only the bit string: reversing the whole ket would also
            # flip its delimiters ("|01〉" -> "〉10|").
            lsb_bits = bin(i)[2:].zfill(n)[::-1]
            coeff = _num2str(amplitude, decimals, atol)
            terms.append(coeff + "|" + lsb_bits + "〉")
    # join() avoids the trailing-separator slicing that previously left a
    # dangling space at the end of the string.
    return " + ".join(terms)


def differences(
    unitary1: numpy.ndarray,
    unitary2: numpy.ndarray,
    decimals: int = 2,
    msb: bool = True,
    atol: float = ABSOLUTE_TOLERANCE,
) -> str:
    """Describe every entry where two matrices differ by more than ``atol``.

    Each differing entry ``(i, j)`` produces one line of the form
    ``|i〉|j〉 : a ≠ b with a difference of d`` where ``a`` and ``b`` are the
    entries of ``unitary1`` and ``unitary2`` rounded to ``decimals`` digits
    and ``d`` is the unrounded absolute difference. Entries are listed in
    row-major order.

    :param unitary1: first matrix; ``round(log2(unitary1.shape[0]))`` fixes
        the number of bits used to print the indices (presumably a power of
        two in practice).
    :param unitary2: second matrix, compared entry-wise with ``unitary1``.
    :param decimals: number of decimals used to display the two entries.
    :param msb: if True the indices are printed most-significant bit first,
        otherwise least-significant bit first.
    :param atol: entries whose absolute difference is <= ``atol`` are ignored.
    :return: the concatenated lines (each ending with a newline), or "" when
        no entry differs by more than ``atol``.
    """
    n = numpy.round(numpy.log2(unitary1.shape[0])).astype(int)
    difference = numpy.abs(unitary1 - unitary2)

    def index_ket(index) -> str:
        ket = msb2str(index, n)
        if msb:
            return ket
        # Reverse only the bits. Reversing the whole "|i〉|j〉" string would
        # also flip the "|"/"〉" delimiters and swap the row and column kets.
        return "|" + ket[1:-1][::-1] + "〉"

    lines = []
    rows, cols = numpy.nonzero(difference > atol)
    for i, j in zip(rows, cols):
        lines.append(
            "{msb_state} : {a} ≠ {b} with a difference of {diff}\n".format(
                msb_state=index_ket(i) + index_ket(j),
                a=round(unitary1[i, j], decimals),
                b=round(unitary2[i, j], decimals),
                diff=difference[i, j],
            )
        )
    return "".join(lines)


def transformation(
    unitary: numpy.ndarray, decimals: int = 2, rows_to_print: numpy.ndarray = None
) -> str:
    """Show, for selected basis states, the corresponding row of ``unitary``.

    One line ``|i〉 → <row i formatted by msbqstate2str>`` is produced per
    index in ``rows_to_print`` (all rows by default). Indices use
    ``round(log2(unitary.shape[0]))`` bits, most-significant bit first.

    NOTE(review): in the usual column-vector convention U|i〉 is column ``i``
    of U; indexing ``unitary[row]`` prints row ``i``, i.e. the image of |i〉
    under U^T. Confirm this is the intended convention.

    :param unitary: square matrix whose rows are printed.
    :param decimals: number of decimals used for the coefficients.
    :param rows_to_print: row indices to print; ``None`` means every row.
    :return: the concatenated lines, each ending with a newline.
    """
    qubit_count = numpy.round(numpy.log2(unitary.shape[0])).astype(int)
    if rows_to_print is None:
        rows_to_print = numpy.arange(unitary.shape[0])
    return "".join(
        msb2str(row, qubit_count)
        + " → "
        + msbqstate2str(unitary[row], decimals=decimals)
        + "\n"
        for row in rows_to_print
    )
